package godb

import (
	"database/sql"
	"fmt"
	sqlserver "github.com/denisenkom/go-mssqldb"
	"github.com/sirupsen/logrus"
	"github.com/spf13/cast"
	"io"
	"strings"
	"sync"
)

// float64ToString formats f in fixed-point notation with four digits after
// the decimal point, then strips trailing fractional zeros and a dangling
// decimal point, so 1.5 becomes "1.5" and 100 becomes "100".
//
// A value that rounds to zero is always rendered as "0"; without the final
// check a small negative input such as -0.00001 would come out as "-0".
func float64ToString(f float64) string {
	s := fmt.Sprintf("%.4f", f)
	// For finite values the formatted text always contains a ".", so only
	// fractional zeros can be removed here; integer digits are never touched.
	s = strings.TrimRight(s, "0")
	s = strings.TrimRight(s, ".")
	if s == "-0" {
		return "0"
	}
	return s
}

// valString converts an arbitrary value to its string form. float64 and
// float32 values go through float64ToString; every other value is converted
// with cast.ToString and has surrounding whitespace trimmed.
func valString(i interface{}) string {
	if f64, ok := i.(float64); ok {
		return float64ToString(f64)
	}
	if f32, ok := i.(float32); ok {
		return float64ToString(float64(f32))
	}
	text := cast.ToString(i)
	return strings.TrimSpace(text)
}

// NewRowsResult builds a QueryResult from rows, remembering the SQL text and
// arguments that produced it.
//
// A nil rows yields a result with an empty column list and no data. Otherwise
// the column names are read; if that fails the error is stored and nothing
// else is read. On success the database type name of every column is recorded
// (skipped when ColumnTypes reports an error) and all rows are read
// immediately through passRows, so the caller never has to close rows itself.
func NewRowsResult(rows *sql.Rows, sql string, args []interface{}) *QueryResult {
	result := &QueryResult{sql: sql, args: args}
	if rows == nil {
		result.columns = []string{}
		result.data = nil
		return result
	}

	columns, err := rows.Columns()
	result.columns = columns
	if err != nil {
		// result.rows stays nil, so there is nothing for passRows to read.
		result.err = err
		return result
	}

	if types, typesErr := rows.ColumnTypes(); typesErr == nil {
		names := make([]string, len(types))
		for idx := range types {
			names[idx] = types[idx].DatabaseTypeName()
		}
		result.columnTypes = names
	}

	result.rows = rows
	result.passRows() // read everything now so the result set is always closed
	return result
}

// ErrQueryResult returns a QueryResult that carries err together with the SQL
// text and arguments of the failed query. When the package logger is set, the
// error is also logged with the database name, SQL text and arguments.
func ErrQueryResult(err error, db string, sql string, args []interface{}) *QueryResult {
	failed := &QueryResult{sql: sql, args: args, err: err}
	if log == nil {
		return failed
	}
	fields := logrus.Fields{
		"db":     db,
		"sql":    sql,
		"params": args,
	}
	log.WithFields(fields).WithError(err).Error("SQL错误")
	return failed
}

// QueryResult holds the outcome of a query: column metadata, the row values
// read from the result set, any error, and the SQL text and arguments used.
type QueryResult struct {
	columns     []string        // column names of the result set
	columnTypes []string        // database type name of each column
	data        [][]interface{} // row values, one inner slice per row
	datalength  int             // row count reported by Length; set by passRows
	rows        *sql.Rows       // underlying result set; set to nil once passRows has consumed it
	err         error           // error recorded for the query
	sql         string          // SQL text of the query
	args        []interface{}   // arguments passed with the query
}

// ErrorToLog writes the stored error to log at error level, tagged with the
// SQL text and arguments. Any msg values are appended after the fixed
// "SQL错误:" prefix. Nothing is logged when there is no error or log is nil.
// The receiver is returned so calls can be chained.
func (r *QueryResult) ErrorToLog(log *logrus.Entry, msg ...string) *QueryResult {
	if r.err == nil || log == nil {
		return r
	}
	entry := log.WithFields(logrus.Fields{
		"sql":    r.sql,
		"params": r.args,
	}).WithError(r.err)
	if len(msg) == 0 {
		entry.Error("SQL错误")
		return r
	}
	parts := make([]interface{}, 0, len(msg)+1)
	parts = append(parts, "SQL错误:")
	for i := range msg {
		parts = append(parts, msg[i])
	}
	entry.Error(parts...)
	return r
}

// Error returns the error recorded for the query, or nil when there is none.
func (r *QueryResult) Error() error {
	return r.err
}

// IsEmpty reports whether the result holds no rows (datalength below 1).
func (r *QueryResult) IsEmpty() bool {
	return r.datalength < 1
}

// passRows reads every row from r.rows into r.data, closes the result set and
// clears r.rows. It does nothing when r.rows is nil.
//
// Each value is scanned into an interface{} so the Go type chosen by the
// driver is kept. Columns whose database type name is "UNIQUEIDENTIFIER" are
// replaced by their string form when the raw value can be scanned as one.
//
// If scanning a row fails, the rows read so far are discarded, the error is
// stored in r.err and reading stops. Discarding keeps r.data in step with
// r.datalength, which the accessors use as the row count. Otherwise r.err
// receives rows.Err() and a non-nil error is logged when the package logger
// is set.
func (r *QueryResult) passRows() {
	if r.rows == nil {
		return
	}
	rows := r.rows
	r.rows = nil // the result set is consumed and closed on every path below

	// Positions of MSSQL UNIQUEIDENTIFIER columns. A ColumnTypes error only
	// disables the conversion, so it is deliberately not treated as fatal.
	var guidColumns []int
	if columnTypes, err := rows.ColumnTypes(); err == nil {
		for i, ct := range columnTypes {
			if ct.DatabaseTypeName() == "UNIQUEIDENTIFIER" {
				guidColumns = append(guidColumns, i)
			}
		}
	}

	r.data = make([][]interface{}, 0)
	for rows.Next() {
		row := make([]interface{}, len(r.columns))
		for i := range row {
			row[i] = new(interface{})
		}
		if err := rows.Scan(row...); err != nil {
			rows.Close()
			// Drop the partial rows so data and datalength cannot disagree.
			r.data = nil
			r.datalength = 0
			r.err = err
			return
		}
		// Replace each scan target with the value it received.
		for i, cell := range row {
			row[i] = *cell.(*interface{})
		}
		for _, idx := range guidColumns {
			if row[idx] == nil {
				continue
			}
			var guid sqlserver.UniqueIdentifier
			if err := guid.Scan(row[idx]); err == nil {
				row[idx] = guid.String()
			}
		}
		r.data = append(r.data, row)
	}
	r.datalength = len(r.data)
	rows.Close()
	r.err = rows.Err() // error met while iterating, if any
	if r.err != nil && log != nil {
		log.WithFields(logrus.Fields{
			"sql":    r.sql,
			"params": r.args,
		}).WithError(r.err).Error("SQL错误")
	}
}

// Get returns the value of column columnName in one row. The optional index
// selects the row and defaults to 0, the first row.
//
// It returns nil when the row index is negative or past the last row, or when
// no column has the given name.
func (r *QueryResult) Get(columnName string, index ...int) interface{} {
	row := 0
	if len(index) > 0 {
		row = index[0]
	}
	// Reject negative indexes as well as indexes beyond the available rows;
	// either would otherwise panic on the slice access below.
	if row < 0 || row >= r.datalength || row >= len(r.data) {
		return nil
	}
	for i, name := range r.columns {
		if name == columnName {
			return r.data[row][i]
		}
	}
	return nil
}

// GetMap returns one row as a map from column name to value. The optional
// index selects the row and defaults to 0, the first row.
//
// It returns nil when the row index is negative or past the last row.
func (r *QueryResult) GetMap(index ...int) map[string]interface{} {
	row := 0
	if len(index) > 0 {
		row = index[0]
	}
	// Reject negative indexes as well as indexes beyond the available rows;
	// either would otherwise panic on the slice access below.
	if row < 0 || row >= r.datalength || row >= len(r.data) {
		return nil
	}
	values := r.data[row]
	ret := make(map[string]interface{}, len(r.columns))
	for i, name := range r.columns {
		ret[name] = values[i]
	}
	return ret
}

// Columns returns the column names of the result. The returned slice is the
// one held by the result, not a copy.
func (r *QueryResult) Columns() []string {
	return r.columns
}

// ColumnTypes returns the database type name recorded for each column.
func (r *QueryResult) ColumnTypes() []string {
	return r.columnTypes
}

// Rows returns all row values, one inner slice per row. The returned slice is
// the one held by the result, not a copy.
func (r *QueryResult) Rows() [][]interface{} {
	return r.data
}

// GetStringMap returns one row as a map from column name to the string form
// of its value, as produced by valString. The optional index selects the row
// and defaults to 0, the first row.
//
// It returns nil when the row index is negative or past the last row.
func (r *QueryResult) GetStringMap(index ...int) map[string]string {
	row := 0
	if len(index) > 0 {
		row = index[0]
	}
	// Reject negative indexes as well as indexes beyond the available rows;
	// either would otherwise panic on the slice access below.
	if row < 0 || row >= r.datalength || row >= len(r.data) {
		return nil
	}
	values := r.data[row]
	ret := make(map[string]string, len(r.columns))
	for i, name := range r.columns {
		ret[name] = valString(values[i])
	}
	return ret
}

// GetCsvData returns the result as rows of strings suitable for CSV output.
// The first entry is the column-name slice held by the result (not a copy);
// each following entry is one data row with every value converted by
// valString.
//
// Only the first datalength rows of the stored data are emitted, and the
// output is sized from the rows actually present. This keeps the method from
// indexing past its output when data and datalength disagree.
func (r *QueryResult) GetCsvData() [][]string {
	count := r.datalength
	if count > len(r.data) {
		count = len(r.data)
	}
	if count < 0 {
		count = 0
	}
	ret := make([][]string, 0, count+1)
	ret = append(ret, r.columns)
	for _, row := range r.data[:count] {
		line := make([]string, len(row))
		for i, v := range row {
			line[i] = valString(v)
		}
		ret = append(ret, line)
	}
	return ret
}

// Length returns the number of rows in the result.
func (r *QueryResult) Length() int {
	return r.datalength
}

// SQLParams returns the SQL text of the query and the arguments passed with it.
func (r *QueryResult) SQLParams() (string, []interface{}) {
	return r.sql, r.args
}

// ForEach calls f once per row, in order, passing the row as a map from
// column name to value. Iteration stops early when f returns false. Nothing
// happens when f is nil or the result has no rows. The receiver is returned
// so calls can be chained.
//
// The same map is reused for every call and overwritten with the next row,
// so f must copy it if it needs to keep the values.
func (r *QueryResult) ForEach(f func(map[string]interface{}) bool) *QueryResult {
	if f == nil || r.datalength < 1 {
		return r
	}
	record := map[string]interface{}{}
	rows := r.data
	for n := 0; n < len(rows) && n < r.datalength; n++ {
		values := rows[n]
		for col, name := range r.columns {
			record[name] = values[col]
		}
		if keepGoing := f(record); !keepGoing {
			return r
		}
	}
	return r
}

// Iterator returns a cursor over the rows of the result, positioned before
// the first row. The cursor shares the row and column slices of the result.
func (r *QueryResult) Iterator() RowIterator {
	// index, row and lock keep their zero values: position 0, no row map
	// yet, and an unlocked mutex.
	cursor := &rowDataIterator{
		columns: r.columns,
		data:    r.data,
	}
	return cursor
}

// rowDataIterator is a cursor over rows already loaded in memory.
type rowDataIterator struct {
	data    [][]interface{}        // row values, one inner slice per row
	columns []string               // column names, used as map keys by Next
	index   int                    // position of the next row to return
	row     map[string]interface{} // map reused by Next for every row it returns
	lock    sync.Mutex             // held by HasNext, Reset and Next while they run
}

// Length returns the total number of rows the iterator covers, regardless of
// the current position.
func (r *rowDataIterator) Length() int {
	return len(r.data)
}

// HasNext reports whether at least one row remains after the current
// position. The position is read while holding the iterator lock.
func (r *rowDataIterator) HasNext() bool {
	r.lock.Lock()
	remaining := r.index < len(r.data)
	r.lock.Unlock()
	return remaining
}

// Reset moves the iterator back to the first row. The position is written
// while holding the iterator lock.
func (r *rowDataIterator) Reset() {
	r.lock.Lock()
	r.index = 0
	r.lock.Unlock()
}

// Next returns the row at the current position as a map from column name to
// value and advances the position. It returns nil and io.EOF once every row
// has been returned.
//
// The end-of-data check and the advance both happen while holding the
// iterator lock, so concurrent callers cannot both pass the check and then
// index past the last row.
//
// The returned map is reused: the next call overwrites it with the following
// row, so callers must copy it if they need to keep the values.
func (r *rowDataIterator) Next() (map[string]interface{}, error) {
	r.lock.Lock()
	defer r.lock.Unlock()
	if r.index >= len(r.data) {
		return nil, io.EOF
	}
	if r.row == nil {
		r.row = map[string]interface{}{}
	}
	values := r.data[r.index]
	for i, name := range r.columns {
		r.row[name] = values[i]
	}
	r.index++
	return r.row, nil
}

// NewExecResult wraps the result of an executed statement together with the
// SQL text and arguments that produced it. The stored error is left nil.
func NewExecResult(rs sql.Result, sql string, args []interface{}) *ExecResult {
	wrapped := new(ExecResult)
	wrapped.Result = rs
	wrapped.sql = sql
	wrapped.args = args
	return wrapped
}

// ErrExecResult returns an ExecResult that carries err together with the SQL
// text and arguments of the failed statement. When the package logger is set,
// the error is also logged with the database name, SQL text and arguments.
func ErrExecResult(err error, db string, sql string, args []interface{}) *ExecResult {
	failed := &ExecResult{sql: sql, args: args, err: err}
	if log == nil {
		return failed
	}
	fields := logrus.Fields{
		"db":     db,
		"sql":    sql,
		"params": args,
	}
	log.WithFields(fields).WithError(err).Error("SQL错误")
	return failed
}

// ExecResult holds the outcome of an executed statement: the embedded
// sql.Result, any error, and the SQL text and arguments used.
type ExecResult struct {
	sql.Result
	sql  string
	args []interface{}
	err  error // error recorded for the statement
}

// ErrorToLog writes the stored error to log at error level, tagged with the
// SQL text under "sql" and the arguments under "data". Any msg values are
// appended after the fixed "SQL错误:" prefix. Nothing is logged when there is
// no error or log is nil. The receiver is returned so calls can be chained.
func (r *ExecResult) ErrorToLog(log *logrus.Entry, msg ...string) *ExecResult {
	if r.err == nil || log == nil {
		return r
	}
	entry := log.WithFields(logrus.Fields{
		"sql":  r.sql,
		"data": r.args,
	}).WithError(r.err)
	if len(msg) == 0 {
		entry.Error("SQL错误")
		return r
	}
	parts := make([]interface{}, 0, len(msg)+1)
	parts = append(parts, "SQL错误:")
	for i := range msg {
		parts = append(parts, msg[i])
	}
	entry.Error(parts...)
	return r
}

// Error returns the error recorded for the statement, if any.
//
// When the optional reportZeroChange is true and the statement itself
// succeeded, the affected-row count is checked as well: zero affected rows
// yields SQLEmptyChange. If the count cannot be obtained, the error from
// RowsAffected is returned rather than being reported as an empty change.
// The count is not queried when zero-change reporting is off.
func (r *ExecResult) Error(reportZeroChange ...bool) error {
	if r.err != nil {
		return r.err
	}
	if len(reportZeroChange) == 0 || !reportZeroChange[0] {
		return nil
	}
	changed, err := r.RowsAffected()
	if err != nil {
		return err
	}
	if changed == 0 {
		return SQLEmptyChange
	}
	return nil
}

// SQLParams returns the SQL text of the statement and the arguments passed
// with it.
func (r *ExecResult) SQLParams() (string, []interface{}) {
	return r.sql, r.args
}

// LastInsertId returns the integer the database generated in response to the
// command, typically the value of an "auto increment" column after an insert.
// Not every database supports this, and the statement syntax varies.
//
// When no underlying result is present it returns 0 and the stored error.
func (r *ExecResult) LastInsertId() (int64, error) {
	if r.Result == nil {
		return 0, r.err
	}
	id, err := r.Result.LastInsertId()
	return id, err
}

// RowsAffected returns how many rows an update, insert or delete touched.
// Not every database or driver supports this.
//
// When no underlying result is present it returns 0 and the stored error.
func (r *ExecResult) RowsAffected() (int64, error) {
	if r.Result == nil {
		return 0, r.err
	}
	count, err := r.Result.RowsAffected()
	return count, err
}
